import difflib
import xml.etree.ElementTree as ET
import os


def change_all_xml(xml_path, labels=None):
    """Fix unknown object labels in every XML annotation file of a folder.

    Every ``.xml`` file directly inside *xml_path* is parsed. For each
    ``<object>`` element whose ``<name>`` text is not one of the valid
    labels, the name is replaced by the closest valid label as ranked by
    ``difflib.get_close_matches``. A file is rewritten in place, once, only
    when at least one of its names was changed.

    Objects without a usable ``<name>`` and names with no close match are
    reported and left untouched instead of aborting the whole run.

    Args:
        xml_path: Directory containing the XML annotation files.
        labels: Iterable of valid label strings. When omitted, the
            module-level ``label_list`` is used, which keeps existing
            ``change_all_xml(xml_path)`` calls working.
    """
    if labels is None:
        labels = label_list  # backward-compatible fallback to the module global
    labels = list(labels)
    known_labels = set(labels)  # O(1) membership checks

    # Only XML files are annotation files; anything else in the folder would
    # make ET.parse fail.
    xml_files = sorted(
        entry for entry in os.listdir(xml_path)
        if entry.lower().endswith('.xml')
    )
    print(len(xml_files))

    for xmlfile in xml_files:
        file_path = os.path.join(xml_path, xmlfile)
        doc = ET.parse(file_path)
        modified = False

        for obj in doc.getroot().findall('object'):
            name_elem = obj.find('name')
            if name_elem is None or not name_elem.text:
                print('file:', xmlfile)
                print('skip object without a name')
                continue

            name = name_elem.text
            if name in known_labels:
                continue

            # Find the closest valid label; the list may be empty.
            matches = difflib.get_close_matches(name, labels)
            print('file:', xmlfile)
            if not matches:
                print('no close match for {}'.format(name))
                continue

            name_elem.text = matches[0]
            modified = True
            print('fix {} to {}'.format(name, matches[0]))

        # Write once per file, and only when something actually changed.
        if modified:
            doc.write(file_path)




if __name__ == '__main__':
    # xml_path = r'D:\images\dataset_origin2\13_finished\xml'
    dataset_root = r'dataset'
    xml_path = os.path.join(dataset_root, 'train', 'Annotations')

    # Load the valid labels, one per line: trailing whitespace is stripped
    # and blank lines are dropped. `label_list` stays a module-level name
    # because change_all_xml reads it as a global.
    label_list = []
    with open(os.path.join(dataset_root, "label_list.txt"), 'r') as label_file:
        for raw_line in label_file:
            label = raw_line.rstrip()
            if label != '':
                label_list.append(label)

    change_all_xml(xml_path)
